import hashlib
import json
import os
import requests
import zipfile
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

download_folder = r'./DepthTrack'  # path of the folder that downloads are written to

def download_file(file_info):
    """Download one file into ``download_folder`` while showing a progress bar.

    Args:
        file_info: Mapping with at least the keys ``'name'`` (file name used
            below ``download_folder``) and ``'url'`` (address to fetch).

    Returns:
        A ``(file_name, success)`` tuple. ``success`` is ``False`` when any
        exception was raised during the transfer; the error is printed.
    """
    file_name = file_info['name']
    url = file_info['url']
    file_path = os.path.join(download_folder, file_name)
    block_size = 64 * 1024  # 64 KiB chunks keep per-chunk overhead low on large archives
    try:
        # The timeout stops a stalled connection from blocking a worker forever;
        # the context manager releases the connection on every exit path.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail before opening the target so an HTTP error page is never
            # written to disk in place of the real file.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(file_path, 'wb') as file, tqdm(
                total=total_size, unit='iB', unit_scale=True, desc=file_name
            ) as progress_bar:
                for data in response.iter_content(block_size):
                    file.write(data)
                    progress_bar.update(len(data))
        return file_name, True
    except Exception as e:
        print(f"Error downloading {file_name}: {e}")
        return file_name, False

def download_files(file_infos, max_workers=5):
    """Download every entry of ``file_infos`` concurrently on a thread pool.

    Args:
        file_infos: Sequence of mappings accepted by ``download_file``.
        max_workers: Maximum number of worker threads.

    Returns:
        List of ``(file_name, success)`` tuples in completion order. A task
        whose future raised is reported on stdout and left out of the list.
    """
    print(f'start downloading {len(file_infos)} files')
    outcomes = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Remember which input produced each future so failures can be named.
        pending = {}
        for info in file_infos:
            pending[pool.submit(download_file, info)] = info
        for finished in as_completed(pending):
            source_info = pending[finished]
            try:
                name, ok = finished.result()
            except Exception as exc:
                print(f"{source_info['name']} error: {exc}")
            else:
                outcomes.append((name, ok))
    return outcomes

def generate_download_items(json_record: int | list) -> list:
    if isinstance(json_record, list):
        download_items = []
        for record in json_record:
            download_items.extend(generate_download_items(record))
        return download_items

    json_url = f'https://zenodo.org/records/{json_record}/export/json'
    save_to = os.path.join(download_folder, f'{json_record}.json')
    if not os.path.exists(save_to):
        response = requests.get(json_url, stream=True)
        if response.status_code == 200:
            with open(save_to, 'wb') as file:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        file.write(chunk)
        else:
            raise Exception(f"Failed to download file from {json_url}")
    with open(save_to, 'r', encoding='utf-8') as file:
        data = json.load(file)
    # 生成下载列表
    download_items = []
    entries = data['files']['entries']
    for entry in entries:
        dic = entries[entry]
        item = {
            'name': dic['key'],
            'url': dic['links']['content'],
            'size': dic['size'],
            'md5': dic['checksum'][4:]
        }
        download_items.append(item)
    return download_items

def generate_split_txt(download_item: list, item_type: str) -> None:
    """Write ``depthtrack_<item_type>.txt`` unless it already exists.

    The file holds one line per item: the item's ``'name'`` up to (not
    including) its first ``'.'``.

    Args:
        download_item: List of dicts that each carry a ``'name'`` key.
        item_type: Suffix used in the output file name.
    """
    txt_file_path = os.path.join(download_folder, f'depthtrack_{item_type}.txt')
    if os.path.exists(txt_file_path):
        # An existing split file is left untouched.
        return
    with open(txt_file_path, 'w', encoding='utf-8') as file:
        stems = [entry['name'].partition('.')[0] for entry in download_item]
        file.write('\n'.join(stems))

def _md5_of_file(file_path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.md5()
    with open(file_path, 'rb') as file:
        while chunk := file.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()


def check_exists(download_items: list) -> list:
    """Return the items that still have to be downloaded.

    An item needs downloading when its file is missing from
    ``download_folder`` or when the file's MD5 differs from ``item['md5']``.

    Args:
        download_items: List of dicts with ``'name'`` and ``'md5'`` keys.

    Returns:
        A new list (never ``download_items`` itself): missing items first, in
        input order, followed by existing items whose checksum did not match.
    """
    # Single pass instead of a quadratic ``item not in list`` membership test.
    exist_item_list = []
    need_to_download = []
    for item in download_items:
        if os.path.exists(os.path.join(download_folder, item['name'])):
            exist_item_list.append(item)
        else:
            need_to_download.append(item)
    print(f'find {len(exist_item_list)} exist files')
    if not exist_item_list:
        # Return the fresh list so callers that mutate the result cannot
        # accidentally modify the list they passed in.
        return need_to_download
    with tqdm(total=len(exist_item_list), desc='checking exist') as pbar:
        correct_zip = 0
        wrong_zip = 0
        for item in exist_item_list:
            file_path = os.path.join(download_folder, item['name'])
            # Hash in chunks: the file is closed promptly and never has to
            # fit in memory.
            if item['md5'] == _md5_of_file(file_path):
                correct_zip += 1
            else:
                need_to_download.append(item)
                wrong_zip += 1
            pbar.update(1)
            pbar.desc = f'correct: {correct_zip}, wrong: {wrong_zip}'
    print(f'remove {correct_zip} files')
    return need_to_download

def task_download(total_download_items) -> list:
    """Download whatever ``check_exists`` reports as missing or corrupt.

    Args:
        total_download_items: Candidate items to check and fetch.

    Returns:
        The list returned by ``check_exists``, i.e. the items that were
        handed to ``download_files``.
    """
    pending_items = check_exists(total_download_items)
    download_files(pending_items)
    return pending_items

def _streamed_md5(file_path, chunk_size=1 << 20):
    """Return the hex MD5 digest of ``file_path``, read chunk by chunk."""
    digest = hashlib.md5()
    with open(file_path, 'rb') as file:
        while chunk := file.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()


def task_check_zip(check_download_items):
    """Verify downloaded files and return the ones that are still wrong.

    A file is correct when it exists in ``download_folder`` and its MD5
    equals ``item['md5']``.

    Args:
        check_download_items: List of dicts with ``'name'`` and ``'md5'``.

    Returns:
        A new list with the items that are missing or fail the checksum.
        The input list is not modified.
    """
    # Collect failures in a separate list. Removing from the list being
    # iterated skips the element after each removal, leaving files unchecked.
    still_wrong = []
    with tqdm(total=len(check_download_items), desc='checking files') as pbar:
        correct_zip = 0
        wrong_zip = 0
        for item in check_download_items:
            file_path = os.path.join(download_folder, item['name'])
            if os.path.exists(file_path) and item['md5'] == _streamed_md5(file_path):
                correct_zip += 1
            else:
                still_wrong.append(item)
                wrong_zip += 1
            pbar.update(1)
            pbar.desc = f'correct: {correct_zip}, wrong: {wrong_zip}'
    return still_wrong

def task_unzip(download_items):
    """Extract every item's archive into ``download_folder``.

    Archives that cannot be extracted are reported on stdout together with
    the reason, and their names are listed once all items were processed.

    Args:
        download_items: List of dicts with a ``'name'`` key naming a zip file
            inside ``download_folder``.
    """
    wrong_items = []
    success_num = 0
    with tqdm(total=len(download_items), desc='unzip files') as pbar:
        for item in download_items:
            pbar.desc = f'success: {success_num}, unziping: {item["name"]}'
            file_path = os.path.join(download_folder, item['name'])
            try:
                with zipfile.ZipFile(file_path, 'r') as zip_ref:
                    zip_ref.extractall(download_folder)
                success_num += 1
            except Exception as exc:
                # ``Exception`` rather than a bare ``except`` so Ctrl-C and
                # ``SystemExit`` still stop the run.
                wrong_items.append(item)
                print(f'failed to unzip {file_path}: {exc}')
            pbar.update(1)
    print(f'{len(wrong_items)} files failed to unzip:')
    for item in wrong_items:
        print(item['name'])

if __name__ == "__main__":
    # Script entry point: collect the file lists, download and verify the
    # archives, then optionally extract them.
    auto_unzip = True
    max_retries = 5  # upper bound on re-download rounds for failing files

    json_train_records = [5794115, 5837926]
    json_val_records = [5792146]

    print('generate download items')
    download_item_train = generate_download_items(json_train_records)
    download_item_val = generate_download_items(json_val_records)
    generate_split_txt(download_item_train, 'train')
    generate_split_txt(download_item_val, 'val')

    total_download_items = [*download_item_train, *download_item_val]
    total_download_items.sort(key=lambda x: x['name'])
    print(f'find {len(total_download_items)} zip files')

    # start download
    # Work on a copy: the download/verify helpers may hand back and shrink the
    # list they were given, and the full list is still needed for unzipping.
    need_to_download_items = task_download(list(total_download_items))
    need_to_download_items = task_check_zip(need_to_download_items)

    # check
    retries = 0
    while need_to_download_items and retries < max_retries:
        retries += 1
        print(f'{len(need_to_download_items)} files are wrong, try to download again')
        task_download(need_to_download_items)
        need_to_download_items = task_check_zip(need_to_download_items)

    if need_to_download_items:
        # Give up instead of looping forever on a permanently failing file.
        print(f'{len(need_to_download_items)} files still wrong after {max_retries} retries:')
        for item in need_to_download_items:
            print(item['name'])

    if auto_unzip:
        print('start unzip')
        failed_names = {item['name'] for item in need_to_download_items}
        # Only extract archives that passed verification.
        task_unzip([item for item in total_download_items if item['name'] not in failed_names])

    # finish
    print('finish')